{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hswish, swish and HSigmoid Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from collections import namedtuple\n",
    "# from pyraul.nn import MLP\n",
    "# from pyraul.pipeline import accuracy\n",
    "# from pyraul.pipeline.train_step import train_step\n",
    "from pyraul.tools.dataset import Dataset\n",
    "from pyraul.tools.dumping import dump_weights\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activation functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Swish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Swish(nn.Module):\n",
    "    \"\"\"Swish activation: the input scaled by its own sigmoid, ``x * sigmoid(x)``.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # The gate is held as a submodule under the attribute name ``sigmoid``.\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``x`` multiplied element-wise by ``sigmoid(x)``.\"\"\"\n",
    "        gate = self.sigmoid(x)\n",
    "        return x * gate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hard Swish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HSwish(nn.Module):\n",
    "    \"\"\"Hard Swish activation: ``x * relu6(x + 3) / 6``.\"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        # relu6 works in place on the temporary ``x + 3``, not on the caller's tensor.\n",
    "        clipped = F.relu6(x + 3, inplace=True)\n",
    "        return x * clipped / 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hard Sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HSigmoid(nn.Module):\n",
    "    \"\"\"Hard Sigmoid activation: ``relu6(x + 3) / 6``.\"\"\"\n",
    "\n",
    "    def forward(self, x):\n",
    "        # relu6 works in place on the temporary ``x + 3``, not on the caller's tensor.\n",
    "        clipped = F.relu6(x + 3, inplace=True)\n",
    "        return clipped / 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "## Comparison: h-swish vs swish and ReLU, sigmoid vs h-sigmoid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "#### Forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation grid: -8.0 up to (but excluding) 8.0 in steps of 0.5.\n",
    "x = torch.from_numpy(np.arange(-8.0, 8.0, 0.5, dtype=np.float32))\n",
    "\n",
    "# Run every activation, switched to eval mode, over the same grid.\n",
    "relu_result, swish_result, hswish_result, sigmoid_result, hsigmoid_result = (\n",
    "    activation.eval()(x)\n",
    "    for activation in (nn.ReLU(), Swish(), HSwish(), nn.Sigmoid(), HSigmoid())\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward curves of ReLU, Swish and h-Swish on one shared axis.\n",
    "for _values, _label in (\n",
    "    (relu_result, \"ReLU\"),\n",
    "    (swish_result, \"Swish\"),\n",
    "    (hswish_result, \"h-Swish\"),\n",
    "):\n",
    "    plt.plot(x, _values, label=_label)\n",
    "plt.title('Swish vs Hard Swish vs ReLU functions (forward)')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('F(x)')\n",
    "plt.grid(True)\n",
    "plt.legend(loc='lower right')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward curves of Sigmoid and h-Sigmoid on one shared axis.\n",
    "for _values, _label in (\n",
    "    (sigmoid_result, \"Sigmoid\"),\n",
    "    (hsigmoid_result, \"h-Sigmoid\"),\n",
    "):\n",
    "    plt.plot(x, _values, label=_label)\n",
    "plt.title('Sigmoid vs Hard Sigmoid functions (forward)')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('F(x)')\n",
    "plt.grid(True)\n",
    "plt.legend(loc='lower right')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_grad(function, input_x):\n",
    "    \"\"\"Return the gradient of ``function(x).sum()`` with respect to ``x`` at ``input_x``.\n",
    "\n",
    "    ``function`` is switched to training mode before it is called. ``input_x``\n",
    "    is cloned first, so the gradient is accumulated on the clone only.\n",
    "    \"\"\"\n",
    "    probe = input_x.clone().requires_grad_(True)\n",
    "    function.train()(probe).sum().backward()\n",
    "    return probe.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient grid: -8.0 up to (but excluding) 8.0 in steps of 0.25.\n",
    "x_grad = torch.from_numpy(np.arange(-8.0, 8.0, 0.25, dtype=np.float32))\n",
    "\n",
    "# Autograd gradient of every activation over the same grid.\n",
    "relu_grad_y, swish_grad_y, hswish_grad_y, sigmoid_grad_y, hsigmoid_grad_y = (\n",
    "    get_grad(activation, x_grad)\n",
    "    for activation in (nn.ReLU(), Swish(), HSwish(), nn.Sigmoid(), HSigmoid())\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient curves of ReLU, Swish and h-Swish on one shared axis.\n",
    "for _values, _label in (\n",
    "    (relu_grad_y, \"ReLU\"),\n",
    "    (swish_grad_y, \"Swish\"),\n",
    "    (hswish_grad_y, \"h-Swish\"),\n",
    "):\n",
    "    plt.plot(x_grad, _values, label=_label)\n",
    "plt.title('Swish vs Hard Swish vs ReLU functions (backward)')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('grad F(x)')\n",
    "plt.grid(True)\n",
    "plt.legend(loc='lower right')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Gradient curves of Sigmoid and h-Sigmoid on one shared axis.\n",
    "for _values, _label in (\n",
    "    (sigmoid_grad_y, \"Sigmoid\"),\n",
    "    (hsigmoid_grad_y, \"h-Sigmoid\"),\n",
    "):\n",
    "    plt.plot(x_grad, _values, label=_label)\n",
    "plt.title('Sigmoid vs Hard Sigmoid functions (backward)')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('grad F(x)')\n",
    "plt.grid(True)\n",
    "plt.legend(loc='lower right')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Pair every sample point with the h-sigmoid gradient measured at it.\n",
    "[(point, slope) for point, slope in zip(x_grad, hsigmoid_grad_y)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "source": [
    "#### H-swish gradient formula"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def hswish_grad(x, grad):\n",
    "    \"\"\"Closed-form h-swish gradient at ``x`` for the upstream gradient ``grad``.\n",
    "\n",
    "    Returns ``grad`` for ``x >= 3``, ``grad * (x / 3 + 0.5)`` for ``-3 < x < 3``\n",
    "    and ``0.0`` in every other case.\n",
    "    \"\"\"\n",
    "    if x >= 3.0:\n",
    "        return grad\n",
    "    if x > -3.0:\n",
    "        return grad * (x / 3.0 + 0.5)\n",
    "    return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Closed-form gradient (upstream gradient of 1) at every sample point.\n",
    "formula_hswish_grad = list(map(lambda point: hswish_grad(point, 1), x_grad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Sample point, autograd gradient and closed-form gradient side by side.\n",
    "[row for row in zip(x_grad, hswish_grad_y, formula_hswish_grad)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import Enum\n",
    "\n",
    "class NetType(Enum):\n",
    "    \"\"\"Activation variants that ``gen_net`` can build a ``Toy`` network around.\"\"\"\n",
    "\n",
    "    # Plain integers: a trailing comma after a value would turn it into a 1-tuple.\n",
    "    relu = 0\n",
    "    swish = 1\n",
    "    hswish = 2\n",
    "    sigmoid = 3\n",
    "    hsigmoid = 4\n",
    "    \n",
    "def trace_forward(name, tensor, batch=0, start=0, stop=-1):\n",
    "    \"\"\"Print a header line, a slice of one batch row of ``tensor`` and a separator.\n",
    "\n",
    "    The printed values are ``tensor[batch][start:stop]``; with the default\n",
    "    ``stop=-1`` the last element of the row is not included.\n",
    "    \"\"\"\n",
    "    print(f\"{name} ({tensor.shape}), #{batch}[{start}:{stop}]\")\n",
    "    row_slice = tensor[batch][start:stop]\n",
    "    print(*(value.item() for value in row_slice))\n",
    "    print(\"-----\")\n",
    "    \n",
    "class Toy(nn.Module):\n",
    "    \"\"\"Two-layer perceptron: ``fc1`` -> ``activation`` -> ``fc2`` -> ``LogSoftmax(dim=1)``.\n",
    "\n",
    "    When ``trace`` is truthy it is unpacked as keyword arguments for\n",
    "    ``trace_forward`` and the input plus every intermediate tensor is printed\n",
    "    during ``forward``. Extra keyword arguments are accepted and ignored.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, activation, n_input, n_hidden, n_output, trace = None, **kwargs):\n",
    "        super().__init__()\n",
    "        self.trace = trace\n",
    "\n",
    "        # Layers are created in the order they are applied in forward().\n",
    "        self.fc1 = nn.Linear(n_input, n_hidden)\n",
    "        self.activation = activation\n",
    "        self.fc2 = nn.Linear(n_hidden, n_output)\n",
    "        self.softmax = nn.LogSoftmax(dim=1)\n",
    "\n",
    "    def _dump(self, stage, tensor):\n",
    "        \"\"\"Print ``tensor`` through ``trace_forward`` when tracing is switched on.\"\"\"\n",
    "        if self.trace:\n",
    "            trace_forward(stage, tensor, **self.trace)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the log-probabilities produced for the batch ``x``.\"\"\"\n",
    "        self._dump(\"data\", x)\n",
    "        pipeline = (\n",
    "            (\"fc1\", self.fc1),\n",
    "            (\"act\", self.activation),\n",
    "            (\"fc2\", self.fc2),\n",
    "            (\"softmax\", self.softmax),\n",
    "        )\n",
    "        out = x\n",
    "        for stage, layer in pipeline:\n",
    "            out = layer(out)\n",
    "            self._dump(stage, out)\n",
    "        return out\n",
    "    \n",
    "def gen_net(net_type: NetType, net_config, device, trace = False):\n",
    "    \"\"\"Build a ``Toy`` network with the activation selected by ``net_type``.\n",
    "\n",
    "    ``net_config`` is unpacked into the ``Toy`` constructor and the network is\n",
    "    moved to ``device``. Raises ``NotImplementedError`` when ``net_type`` does\n",
    "    not equal any ``NetType`` member.\n",
    "    \"\"\"\n",
    "    activation_factories = (\n",
    "        (NetType.relu, nn.ReLU),\n",
    "        (NetType.swish, Swish),\n",
    "        (NetType.hswish, HSwish),\n",
    "        (NetType.sigmoid, nn.Sigmoid),\n",
    "        (NetType.hsigmoid, HSigmoid),\n",
    "    )\n",
    "    for kind, make_activation in activation_factories:\n",
    "        if net_type == kind:\n",
    "            return Toy(activation=make_activation(), trace=trace, **net_config).to(device)\n",
    "    raise NotImplementedError(\"Unknown network type\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from typing import Callable, Optional\n",
    "\n",
    "\n",
    "def accuracy(\n",
    "    model: torch.nn.Module,\n",
    "    dataloader: DataLoader,\n",
    "    preprocessor: Optional[Callable] = None,\n",
    "    device: str = \"cpu\",\n",
    "    squeeze_target: bool = False,\n",
    "    **kwargs,\n",
    ") -> float:\n",
    "    \"\"\"\n",
    "    Return the classification accuracy of ``model`` in percent.\n",
    "\n",
    "    Accuracy = 100 * correct answers / total answers. The model is switched to\n",
    "    evaluation mode and is left in that mode on return.\n",
    "\n",
    "    :param model: Neural network model\n",
    "    :param dataloader: Iterable yielding ``(data, labels)`` batches\n",
    "    :param preprocessor: Optional callable applied to every data batch before\n",
    "        it is moved to ``device``\n",
    "    :param device: Device the batches are moved to before the forward pass\n",
    "    :param squeeze_target: If True, ``labels.squeeze()`` is applied to every\n",
    "        label batch before it is compared with the predictions\n",
    "    :param kwargs: Extra keyword arguments; accepted and ignored\n",
    "    :return: Accuracy in percent; 0.0 when the loader yields no samples\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data, labels in dataloader:\n",
    "            if preprocessor:\n",
    "                data = preprocessor(data)\n",
    "            data = data.to(device)\n",
    "            labels = labels.to(device)\n",
    "            if squeeze_target:\n",
    "                labels = labels.squeeze()\n",
    "            outputs = model(data)\n",
    "            # Predicted class = index of the largest score along dim 1.\n",
    "            predicted = outputs.argmax(dim=1)\n",
    "            total += outputs.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    if total == 0:\n",
    "        # Nothing was evaluated, so there is no ratio to compute.\n",
    "        return 0.0\n",
    "    return 100.0 * correct / total\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time \n",
    "from collections import namedtuple\n",
    "from typing import Callable, Optional, List\n",
    "from pyraul.tools.logging import get_fixedwide_str\n",
    "\n",
    "class AverageMeter:\n",
    "    \"\"\"Tracks the latest value, the weighted sum, the count and the running average.\n",
    "\n",
    "    With ``history=True`` every value passed to ``update`` is also appended to\n",
    "    ``self.history``; otherwise that attribute is never created.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, history: bool = False):\n",
    "        self.use_history = history\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Zero all statistics and, if enabled, start a fresh history list.\"\"\"\n",
    "        self.val = self.avg = self.sum = self.count = 0\n",
    "        if self.use_history:\n",
    "            self.history = []\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        \"\"\"Record ``val`` as the latest value, weighted by ``n`` in the average.\"\"\"\n",
    "        self.val = val\n",
    "        self.count += n\n",
    "        self.sum += val * n\n",
    "        self.avg = self.sum / self.count\n",
    "        if self.use_history:\n",
    "            self.history.append(val)\n",
    "            \n",
    "            \n",
    "def _leading_values(tensor, limit=10):\n",
    "    \"\"\"Return up to ``limit`` Python scalars taken from the start of ``tensor``.\n",
    "\n",
    "    Tensors with more than one dimension have their axes reversed and the first\n",
    "    slice taken, i.e. the first column of a 2-D matrix; 1-D tensors are used\n",
    "    as they are.\n",
    "    \"\"\"\n",
    "    # Work on a detached CPU copy so parameters on any device can be listed.\n",
    "    values = tensor.detach().cpu()\n",
    "    if values.dim() > 1:\n",
    "        values = values.permute(*reversed(range(values.dim())))[0]\n",
    "    return [v.item() for v in values][:limit]\n",
    "\n",
    "\n",
    "def show_params(model):\n",
    "    \"\"\"Print name, shape and leading values of every trainable parameter.\n",
    "\n",
    "    If a parameter has a gradient, its shape and leading values are printed\n",
    "    right after the parameter itself.\n",
    "    \"\"\"\n",
    "    print(\"====================================\")\n",
    "    for name, param in model.named_parameters():\n",
    "        if not param.requires_grad:\n",
    "            continue\n",
    "        print(f\"{name}, {param.data.shape}\")\n",
    "        print(_leading_values(param.data))\n",
    "        if param.grad is not None:\n",
    "            print(f\"grad of {name}, {param.grad.shape}\")\n",
    "            print(_leading_values(param.grad))\n",
    "    print(\"====================================\")\n",
    "        \n",
    "TrainStepResult = namedtuple(\"TrainStepResult\", [\"loss\", \"time_batch_load\", \"time_batch_full\"])\n",
    "\n",
    "def train_step(train_loader, \n",
    "               model, \n",
    "               criterion, \n",
    "               optimizer, \n",
    "               device, \n",
    "               print_freq=1,\n",
    "               verbose: bool = True,\n",
    "               loss_history: bool = False,\n",
    "               preprocessor: Optional[Callable] = None):\n",
    "    \"\"\"Run one pass over ``train_loader``, taking an optimizer step per batch.\n",
    "\n",
    "    :param train_loader: Sized iterable yielding ``(data, target)`` batches\n",
    "    :param model: Network to train; switched to training mode\n",
    "    :param criterion: Callable ``criterion(output, target)`` returning the loss\n",
    "    :param optimizer: Optimizer holding the parameters of ``model``\n",
    "    :param device: Device the batches are moved to\n",
    "    :param print_freq: A progress line is printed every ``print_freq`` batches\n",
    "    :param verbose: Print progress lines when True\n",
    "    :param loss_history: Keep every per-batch loss in ``result.loss.history``\n",
    "    :param preprocessor: Optional callable applied to each data batch\n",
    "    :return: ``TrainStepResult`` of three ``AverageMeter`` objects: loss,\n",
    "        data-loading time and full batch time\n",
    "    \"\"\"\n",
    "    batch_time = AverageMeter()\n",
    "    data_time = AverageMeter()\n",
    "    losses = AverageMeter(history=loss_history)\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    n = len(train_loader)\n",
    "    n_wide = len(str(n))\n",
    "\n",
    "    end = time.time()\n",
    "    for i, (batch, target) in enumerate(train_loader):\n",
    "        if preprocessor:\n",
    "            batch = preprocessor(batch)\n",
    "\n",
    "        # Loading time covers fetching the batch and the optional preprocessing.\n",
    "        data_time.update(time.time() - end)\n",
    "\n",
    "        target = target.to(device)\n",
    "        batch = batch.to(device)\n",
    "\n",
    "        # Forward pass and loss.\n",
    "        output = model(batch)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # Backward pass and parameter update.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # The running average is weighted by the number of samples in the batch.\n",
    "        losses.update(loss.float().item(), batch.size(0))\n",
    "\n",
    "        # Full step time, measured from the end of the previous step.\n",
    "        batch_time.update(time.time() - end)\n",
    "        end = time.time()\n",
    "\n",
    "        if verbose and i % print_freq == 0:\n",
    "            print(f\"Step {get_fixedwide_str(str(i), n_wide)}/{n}\\t\"\n",
    "                  f\"Loss: {losses.val:.6f} ({losses.avg:.6f})\\t\"\n",
    "                  f\"Time.step: {batch_time.val:.3f} ({batch_time.avg:.3f})\\t\"\n",
    "                  f\"Time.load: {data_time.val:.3f} ({data_time.avg:.3f})\"\n",
    "                 )\n",
    "    return TrainStepResult(loss=losses, time_batch_load=data_time, time_batch_full=batch_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO: Loading MNIST dataset...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.28\n",
      "Step    0/1200\tLoss: 2.352203 (2.352203)\tTime.step: 0.013 (0.013)\tTime.load: 0.009 (0.009)\n",
      "Step  100/1200\tLoss: 2.276294 (2.296440)\tTime.step: 0.005 (0.006)\tTime.load: 0.004 (0.005)\n",
      "Step  200/1200\tLoss: 2.057391 (2.238303)\tTime.step: 0.006 (0.006)\tTime.load: 0.005 (0.005)\n",
      "Step  300/1200\tLoss: 1.876074 (2.160662)\tTime.step: 0.006 (0.006)\tTime.load: 0.005 (0.005)\n",
      "Step  400/1200\tLoss: 1.765929 (2.063696)\tTime.step: 0.005 (0.006)\tTime.load: 0.004 (0.005)\n",
      "Step  500/1200\tLoss: 1.270914 (1.950968)\tTime.step: 0.005 (0.006)\tTime.load: 0.004 (0.005)\n",
      "Step  600/1200\tLoss: 1.210810 (1.835414)\tTime.step: 0.005 (0.006)\tTime.load: 0.004 (0.005)\n",
      "Step  700/1200\tLoss: 0.949828 (1.733999)\tTime.step: 0.006 (0.006)\tTime.load: 0.005 (0.005)\n",
      "Step  800/1200\tLoss: 0.816559 (1.635761)\tTime.step: 0.007 (0.006)\tTime.load: 0.005 (0.005)\n",
      "Step  900/1200\tLoss: 0.940067 (1.549612)\tTime.step: 0.010 (0.006)\tTime.load: 0.009 (0.005)\n",
      "Step 1000/1200\tLoss: 0.779427 (1.471990)\tTime.step: 0.006 (0.006)\tTime.load: 0.004 (0.005)\n",
      "Step 1100/1200\tLoss: 0.635892 (1.402491)\tTime.step: 0.006 (0.006)\tTime.load: 0.005 (0.005)\n",
      "83.01\n"
     ]
    }
   ],
   "source": [
    "from pyraul.tools.seed import set_seed\n",
    "\n",
    "\n",
    "def flatten_images(images):\n",
    "    \"\"\"Reshape an image batch into rows of 28 * 28 values.\"\"\"\n",
    "    return images.reshape(-1, 28 * 28)\n",
    "\n",
    "\n",
    "config = dict(\n",
    "    batch_size=50,\n",
    "    feature_space_dim=784,\n",
    "    hidden_layer_dim=500,\n",
    "    classes_n=10,\n",
    "    seed=0,\n",
    "    device=\"cuda\",\n",
    "    epochs=50,\n",
    "    sgd={\"lr\": 0.05},\n",
    ")\n",
    "\n",
    "net_config = dict(\n",
    "    n_input=config[\"feature_space_dim\"],\n",
    "    n_hidden=config[\"hidden_layer_dim\"],\n",
    "    n_output=config[\"classes_n\"],\n",
    ")\n",
    "\n",
    "set_seed(config[\"seed\"])\n",
    "\n",
    "device = torch.device(config[\"device\"])\n",
    "model = gen_net(NetType.hsigmoid, net_config, device)\n",
    "\n",
    "# dump_weights(model, \"init.txt\")\n",
    "\n",
    "ds = Dataset(\"MNIST\", **config)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=config[\"sgd\"][\"lr\"])\n",
    "criterion = nn.NLLLoss(reduction=\"mean\")\n",
    "\n",
    "# Test accuracy before any training step.\n",
    "accuracy_before = accuracy(\n",
    "    model=model,\n",
    "    dataloader=ds.test_loader,\n",
    "    preprocessor=flatten_images,\n",
    "    **config,\n",
    ")\n",
    "print(accuracy_before)\n",
    "\n",
    "# One pass over the training set; per-batch losses are kept in loss.history.\n",
    "step_result = train_step(\n",
    "    ds.train_loader,\n",
    "    model,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    device,\n",
    "    print_freq=100,\n",
    "    verbose=True,\n",
    "    loss_history=True,\n",
    "    preprocessor=flatten_images,\n",
    ")\n",
    "loss = step_result.loss\n",
    "\n",
    "# Test accuracy after the training pass.\n",
    "accuracy_after = accuracy(\n",
    "    model=model,\n",
    "    dataloader=ds.test_loader,\n",
    "    preprocessor=flatten_images,\n",
    "    **config,\n",
    ")\n",
    "print(accuracy_after)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "const raul::dtVec idealLosses{ 2.352203369140625_dt, 2.2762935161590576_dt, 2.0573911666870117_dt, 1.876073956489563_dt, 1.7659292221069336_dt, 1.2709139585494995_dt, 1.210809588432312_dt, 0.9498279690742493_dt, 0.8165586590766907_dt, 0.9400674700737_dt, 0.7794268131256104_dt, 0.6358923316001892_dt };\n"
     ]
    }
   ],
   "source": [
    "# Every 100th recorded batch loss, each suffixed with _dt, printed as one brace-enclosed list.\n",
    "ideal_losses = \", \".join(f\"{value}_dt\" for value in loss.history[::100])\n",
    "print(\"const raul::dtVec idealLosses{\", ideal_losses, \"};\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train network with toy dataset (binary classification)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Settings for the toy binary-classification experiment.\n",
    "config = dict(\n",
    "    batches=10,\n",
    "    batch_size=4,\n",
    "    feature_space_dim=16,\n",
    "    classes_n=2,\n",
    "    dataset_offset=[0.0, 0.5],\n",
    "    hidden_layer_dim=64,\n",
    "    seed=0,\n",
    "    device=\"cpu\",\n",
    "    epochs=50,\n",
    "    sgd={\"lr\": 0.05},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from pyraul.tools.seed import set_seed\n",
    "\n",
    "def generate_toy_dataset(dataset_offset: list,\n",
    "                         classes_n: int,\n",
    "                         feature_space_dim: int,\n",
    "                         batches: int,\n",
    "                         batch_size: int,\n",
    "                         seed: int,\n",
    "                         device: str,\n",
    "                         **kwargs):\n",
    "    \"\"\"Generate a synthetic dataset with one shifted Gaussian cloud per class.\n",
    "\n",
    "    Every class gets ``batch_size * batches // classes_n`` vectors drawn from\n",
    "    ``torch.randn`` and shifted by that class's entry in ``dataset_offset``.\n",
    "\n",
    "    :param dataset_offset: One additive offset per class\n",
    "    :param classes_n: Number of classes; must equal ``len(dataset_offset)``\n",
    "    :param feature_space_dim: Length of every feature vector\n",
    "    :param batches: Number of batches the dataset should fill\n",
    "    :param batch_size: Number of vectors per batch\n",
    "    :param seed: Value handed to ``set_seed`` before sampling\n",
    "    :param device: Device the tensors are created on\n",
    "    :param kwargs: Extra keyword arguments; accepted and ignored\n",
    "    :return: ``(x, y)``: features of shape ``(N, feature_space_dim)`` and\n",
    "        integer labels of shape ``(N, 1)``, grouped class by class\n",
    "    \"\"\"\n",
    "    # Seed with the caller's value so the ``seed`` argument actually takes effect.\n",
    "    set_seed(seed)\n",
    "    assert len(dataset_offset) == classes_n\n",
    "\n",
    "    vectors_per_class = batch_size * batches // classes_n\n",
    "    assert vectors_per_class > 0\n",
    "\n",
    "    # Features are sampled class by class, in class order.\n",
    "    x = torch.cat(\n",
    "        [\n",
    "            torch.randn(vectors_per_class, feature_space_dim, device=device) + dataset_offset[class_id]\n",
    "            for class_id in range(classes_n)\n",
    "        ],\n",
    "        dim=0,\n",
    "    )\n",
    "\n",
    "    # Labels follow the same class order as the feature blocks above.\n",
    "    y = torch.cat(\n",
    "        [\n",
    "            torch.full((vectors_per_class, 1), class_id, dtype=torch.long, device=device)\n",
    "            for class_id in range(classes_n)\n",
    "        ],\n",
    "        dim=0,\n",
    "    )\n",
    "\n",
    "    return x, y\n",
    "        \n",
    "x, y = generate_toy_dataset(**config)\n",
    "\n",
    "# Feature 0 is drawn on the horizontal axis and feature 1 on the vertical axis.\n",
    "plt.title('Projection of the feature space')\n",
    "plt.xlabel('x0')\n",
    "plt.ylabel('x1')\n",
    "plt.scatter(x[:,0], x[:, 1], c=y, alpha=0.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "def get_ds(full_dataset):\n",
    "    \"\"\"Randomly split ``full_dataset`` into an 80% train part and the remaining test part.\"\"\"\n",
    "    total = len(full_dataset)\n",
    "    n_train = int(0.8 * total)\n",
    "    return torch.utils.data.random_split(full_dataset, [n_train, total - n_train])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Only DataLoader is imported here: also importing torch's Dataset would shadow\n",
    "# the pyraul Dataset that the MNIST cell relies on.\n",
    "from torch.utils.data import DataLoader\n",
    "device = torch.device(config[\"device\"])\n",
    "\n",
    "net_config = {\n",
    "    \"n_input\": config[\"feature_space_dim\"], \n",
    "    \"n_hidden\": config[\"hidden_layer_dim\"], \n",
    "    \"n_output\": config[\"classes_n\"]\n",
    "}\n",
    "\n",
    "# One freshly initialised Toy network per activation, addressable by name.\n",
    "Nets = namedtuple(\"Nets\", [\"relu\", \"swish\", \"hswish\", \"sigmoid\", \"hsigmoid\"])\n",
    "nets = Nets(\n",
    "    relu = Toy(activation=nn.ReLU(), **net_config).to(device),\n",
    "    swish = Toy(activation=Swish(), **net_config).to(device),\n",
    "    hswish = Toy(activation=HSwish(), **net_config).to(device),\n",
    "    sigmoid = Toy(activation=nn.Sigmoid(), **net_config).to(device),\n",
    "    hsigmoid = Toy(activation=HSigmoid(), **net_config).to(device)\n",
    ")\n",
    "\n",
    "model = nets.hsigmoid\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=config[\"sgd\"][\"lr\"])\n",
    "# Targets are batched as (batch, 1); squeeze only dim 1 so a single-sample\n",
    "# batch keeps its batch dimension.\n",
    "criterion = lambda y,t: nn.NLLLoss()(y, t.squeeze(1))\n",
    "\n",
    "train_ds, test_ds = get_ds(list(zip(x, y)))\n",
    "ds_train_loader = DataLoader(train_ds, batch_size=config[\"batch_size\"], shuffle=True)\n",
    "ds_test_loader = DataLoader(test_ds, batch_size=config[\"batch_size\"], shuffle=True)\n",
    "\n",
    "print(accuracy(model, ds_test_loader, squeeze_target=True))\n",
    "      \n",
    "history_loss=[]\n",
    "history_acc=[]\n",
    "for epoch in range(config[\"epochs\"]):\n",
    "    loss, _, _ = train_step(\n",
    "                        ds_train_loader, \n",
    "                        model,\n",
    "                        criterion,\n",
    "                        optimizer,\n",
    "                        config[\"device\"],\n",
    "                        print_freq=5,\n",
    "                        verbose=False\n",
    "                    )\n",
    "    acc = accuracy(model, ds_test_loader, squeeze_target=True)\n",
    "    history_acc.append(acc)\n",
    "    # train_step returns an AverageMeter; keep its epoch average so the\n",
    "    # history is a list of plain numbers.\n",
    "    history_loss.append(loss.avg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "ax1.grid(True)\n",
    "ax1.set_title('Training')\n",
    "\n",
    "# Accuracy goes on the left axis, loss on a twin right axis sharing the same x axis.\n",
    "ax2 = ax1.twinx()\n",
    "ax1.plot(history_acc, 'g-')\n",
    "ax2.plot(history_loss, 'r-')\n",
    "\n",
    "ax1.set_xlabel('Epoch')\n",
    "ax1.set_ylabel('Test accuracy', color='g')\n",
    "ax2.set_ylabel('Train Loss avg', color='r')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converter prototype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# Pass the most recently assigned `model` to the project's converter helper.\n",
    "# NOTE(review): the helper lives in pyraul and its output is not visible here - confirm what it emits.\n",
    "from pyraul.tools.converter import cvt_model_to_raul\n",
    "cvt_model_to_raul(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the label tensor `y`, cast to float, to the project's tensor converter.\n",
    "# NOTE(review): the helper lives in pyraul and its output is not visible here - confirm what it emits.\n",
    "from pyraul.tools.converter import cvt_tensor_to_raul\n",
    "cvt_tensor_to_raul(y.float())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Watermark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext watermark\n",
    "%watermark -d -u -v -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
